# -*- coding: utf-8 -*-
from __future__ import print_function
"""
Created on Tue Oct 23 14:03:51 2018

@author: xingshuli
"""
# from keras.applications.imagenet_utils import _obtain_input_shape
import keras
from keras_applications.imagenet_utils import _obtain_input_shape

from keras.engine.topology import get_source_inputs
from keras.layers import Input
from keras.layers import Conv2D
from keras.layers import Activation
from keras.layers import Dense
from keras.layers import BatchNormalization
from keras.layers import Concatenate
# from keras.applications.mobilenet import DepthwiseConv2D
from keras.layers import DepthwiseConv2D

from keras.layers import Lambda
from keras.layers import MaxPooling2D
from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D
from keras.models import Model
from keras import backend as K

import numpy as np

class ShuffleNet_V2(object):
    """Builder for a ShuffleNet V2 style Keras model.

    The helpers slice, normalise and concatenate along the last tensor axis,
    so tensors are expected to be laid out channels-last
    (batch, height, width, channels). `build` is the public entry point; the
    other methods assemble the individual units and stages.
    """

    def channel_split(self, x, name = ''):
        """Split a 4-D tensor into two halves along its last (channel) axis.

        # Arguments
            x: 4-D tensor whose last dimension is statically known.
            name: optional prefix used to name the two slicing layers. When
                empty, Keras assigns automatic names.

        # Returns
            Tuple `(c_hat, c)`: `c_hat` holds the first `channels // 2`
            channels and `c` holds the remaining ones.
        """
        in_channels = x.shape.as_list()[-1]
        half = in_channels // 2

        # Only pass explicit names when a prefix is supplied, so a direct call
        # with the default `name` still gets unique auto-generated layer names.
        first_name = '%s/c_hat' % name if name else None
        second_name = '%s/c' % name if name else None

        c_hat = Lambda(lambda z: z[:, :, :, 0:half], name = first_name)(x)
        c = Lambda(lambda z: z[:, :, :, half:], name = second_name)(x)

        return c_hat, c

    def channel_shuffle(self, x):
        """Interleave the two halves of the channel axis.

        The channels are viewed as a (2, channels // 2) grid, the grid is
        transposed and flattened again, e.g. channel order
        [0 1 2 3 4 5 6 7] becomes [0 4 1 5 2 6 3 7].

        # Arguments
            x: 4-D channels-last tensor with a known, even channel count.

        # Returns
            Tensor with the same shape as `x` and shuffled channels.

        # Raises
            ValueError: if the channel count is unknown or odd.
        """
        height, width, channels = x.shape.as_list()[1:]
        if channels is None or channels % 2 != 0:
            raise ValueError('channel_shuffle requires a known, even number of '
                             'channels, got %s' % channels)
        channels_per_split = channels // 2

        # Spatial dimensions may be unknown at graph-construction time (e.g.
        # an input shape of (None, None, 3)); fall back to the runtime shape
        # for those, since `None` is not a valid reshape target.
        if height is None or width is None:
            runtime_shape = K.shape(x)
            if height is None:
                height = runtime_shape[1]
            if width is None:
                width = runtime_shape[2]

        # -1 lets reshape infer the batch dimension.
        x = K.reshape(x, [-1, height, width, 2, channels_per_split])
        x = K.permute_dimensions(x, (0, 1, 2, 4, 3))
        x = K.reshape(x, [-1, height, width, channels])

        return x

    def _shuffle_unit(self, inputs, out_channels, strides = 2, stage = 1, block = 1):
        """Build one shuffle unit and return its output tensor.

        With `strides == 2` both branches consume the full input and
        downsample it:
            branch 1: 3x3 depthwise (stride 2) -> BN -> 1x1 conv -> BN -> ReLU
            branch 2: 1x1 conv -> BN -> ReLU -> 3x3 depthwise (stride 2) -> BN
                      -> 1x1 conv -> BN -> ReLU
        With `strides == 1` the input channels are split in two; the first
        half is passed through unchanged and the second half goes through
            1x1 conv -> BN -> ReLU -> 3x3 depthwise -> BN -> 1x1 conv -> BN -> ReLU
        In both cases the two branches are concatenated on the channel axis
        and then channel-shuffled.

        # Arguments
            inputs: 4-D channels-last input tensor.
            out_channels: channel budget of the unit; each convolutional
                branch produces `out_channels // 2` channels.
            strides: 1 or 2, selects the unit variant described above.
            stage: stage index, used only to build layer names.
            block: block index within the stage, used only for layer names.

        # Returns
            Output tensor of the unit.

        # Raises
            ValueError: if `strides` is neither 1 nor 2.
        """
        # Fail early: previously an unsupported stride skipped both branches
        # and crashed later on an unbound local variable.
        if strides not in (1, 2):
            raise ValueError('strides must be 1 or 2, got %r' % (strides,))

        bn_axis = -1
        prefix = 'stage%d/block%d' %(stage, block)

        branch_channels = out_channels // 2

        if strides == 2:
            # Downsampling unit: both branches start from the same input tensor.
            x_1 = DepthwiseConv2D(kernel_size = 3, strides = 2, padding = 'same',
                                  use_bias = False, name = '%s/3x3dwconv_1' % prefix)(inputs)
            x_1 = BatchNormalization(axis = bn_axis, name = '%s/bn_3x3dwconv_1' % prefix)(x_1)
            x_1 = Conv2D(filters = branch_channels, kernel_size = 1, strides = 1, padding = 'same',
                         use_bias = False, name = '%s/1x1conv_1' % prefix)(x_1)
            x_1 = BatchNormalization(axis = bn_axis, name = '%s/bn_1x1conv_1' % prefix)(x_1)
            x_1 = Activation('relu')(x_1)

            x_2 = Conv2D(filters = branch_channels, kernel_size = 1, strides = 1, padding = 'same',
                         use_bias = False, name = '%s/1x1conv_2' % prefix)(inputs)
            x_2 = BatchNormalization(axis = bn_axis, name = '%s/bn_1x1conv_2' % prefix)(x_2)
            x_2 = Activation('relu')(x_2)
            x_2 = DepthwiseConv2D(kernel_size = 3, strides = 2, padding = 'same',
                                  use_bias = False, name = '%s/3x3dwconv_2' % prefix)(x_2)
            x_2 = BatchNormalization(axis = bn_axis, name = '%s/bn_3x3dwconv_2' % prefix)(x_2)
            x_2 = Conv2D(filters = branch_channels, kernel_size = 1, strides = 1, padding = 'same',
                         use_bias = False, name = '%s/1x1conv_3' % prefix)(x_2)
            x_2 = BatchNormalization(axis = bn_axis, name = '%s/bn_1x1conv_3' % prefix)(x_2)
            x_2 = Activation('relu')(x_2)

            x = Concatenate(axis = bn_axis, name = '%s/concat' % prefix)([x_1, x_2])
        else:
            # Basic unit: split the channels first; only the second half is
            # transformed, the first half is concatenated back untouched.
            c_hat, c = self.channel_split(inputs, name = '%s/split' % prefix)

            c = Conv2D(filters = branch_channels, kernel_size = 1, strides = 1, padding = 'same',
                       use_bias = False, name = '%s/1x1conv_4' % prefix)(c)
            c = BatchNormalization(axis = bn_axis, name = '%s/bn_1x1conv_4' % prefix)(c)
            c = Activation('relu')(c)
            c = DepthwiseConv2D(kernel_size = 3, strides = 1, padding = 'same',
                                use_bias = False, name = '%s/3x3dwconv_3' % prefix)(c)
            c = BatchNormalization(axis = bn_axis, name = '%s/bn_3x3dwconv_3' % prefix)(c)
            c = Conv2D(filters = branch_channels, kernel_size = 1, strides = 1, padding = 'same',
                       use_bias = False, name = '%s/1x1conv_5' % prefix)(c)
            c = BatchNormalization(axis = bn_axis, name = '%s/bn_1x1conv_5' % prefix)(c)
            c = Activation('relu')(c)

            x = Concatenate(axis = bn_axis, name = '%s/concat' % prefix)([c_hat, c])

        x = Lambda(self.channel_shuffle, name = '%s/channel_shuffle' % prefix)(x)

        return x

    def v2_block(self, x, channel_map, repeat = 1, stage = 1):
        """Build one stage: a stride-2 unit followed by `repeat` stride-1 units.

        # Arguments
            x: input tensor of the stage.
            channel_map: sequence of per-stage channel counts; the entry at
                index `stage - 1` is used for every unit of this stage.
            repeat: number of stride-1 units appended after the stride-2 unit.
            stage: stage index, used for the channel lookup and layer names.

        # Returns
            Output tensor of the stage.
        """
        stage_channels = channel_map[stage - 1]

        # Block 1 downsamples; blocks 2 .. repeat + 1 keep the resolution.
        x = self._shuffle_unit(x, out_channels = stage_channels, strides = 2,
                               stage = stage, block = 1)

        for block in range(2, repeat + 2):
            x = self._shuffle_unit(x, out_channels = stage_channels, strides = 1,
                                   stage = stage, block = block)

        return x

    def build(self, include_top = True, input_tensor = None, scale_factor = 1.0, pooling = 'avg',
                      input_shape = (224, 224, 3), num_shuffle_units = (3, 7, 3), weights = None,
                      classes = 1000):
        """Assemble the full network and return it as a Keras `Model`.

        # Arguments
            include_top: whether to append the `classes`-way dense layer and
                softmax after global pooling.
            input_tensor: optional tensor to use as the model input.
            scale_factor: width multiplier; one of 0.5, 1, 1.5 or 2.
            pooling: 'avg' or 'max', the global pooling applied after the
                last convolution (applied regardless of `include_top`).
            input_shape: input shape, checked by `_obtain_input_shape`.
            num_shuffle_units: number of stride-1 units in each stage; its
                length determines the number of stages (numbered from 2).
            weights: only forwarded to `_obtain_input_shape`; this method
                does not load any weights.
            classes: number of output classes when `include_top` is true.

        # Returns
            A Keras `Model` named 'ShuffleNet_V2'.

        # Raises
            ValueError: for an invalid `pooling` or `scale_factor`.
            RuntimeError: if the Keras backend is not TensorFlow.
        """
        # Pure argument validation first, so bad arguments fail fast.
        if pooling not in ['max', 'avg']:
            raise ValueError('Invalid value for pooling')

        if not (float(scale_factor) * 4).is_integer():
            raise ValueError('Invalid value for scale_factor. Should be x over 4')

        # Stage-2 output channels for each supported width multiplier.
        out_dim_stage_two = {0.5: 48, 1: 116, 1.5: 176, 2: 244}
        if scale_factor not in out_dim_stage_two:
            # Multiples of 0.25 such as 0.75 pass the check above but have no
            # channel configuration; report that instead of a bare KeyError.
            raise ValueError('Unsupported scale_factor %r. Supported values: %s'
                             % (scale_factor, sorted(out_dim_stage_two)))

        if K.backend() != 'tensorflow':
            raise RuntimeError('Only TensorFlow backend is currently supported')

        # NOTE(review): the layers below hard-code the channel axis as -1, so
        # this presumably requires K.image_data_format() == 'channels_last'.
        input_shape = _obtain_input_shape(input_shape,
                                          default_size = 224,
                                          min_size = 28,
                                          data_format = K.image_data_format(),
                                          require_flatten = include_top,
                                          weights = weights)

        # Index 0 is the stem (24 filters); each following stage doubles the
        # stage-2 width: [24, base, 2 * base, 4 * base, ...].
        base_channels = out_dim_stage_two[scale_factor]
        out_channels_in_stage = [24] + [base_channels * 2 ** i
                                        for i in range(len(num_shuffle_units))]

        if input_tensor is None:
            img_input = Input(shape = input_shape)
        elif not K.is_keras_tensor(input_tensor):
            img_input = Input(shape = input_shape, tensor = input_tensor)
        else:
            img_input = input_tensor

        x = Conv2D(filters = out_channels_in_stage[0], kernel_size = 3, strides = 2,
                   padding = 'same', use_bias = False, activation = 'relu', name = 'conv1')(img_input)

        x = MaxPooling2D(pool_size = 3, strides = 2, padding = 'same', name = 'MaxPool1')(x)

        # Construct the stages; stage numbering starts at 2.
        for index, repeat in enumerate(num_shuffle_units):
            x = self.v2_block(x, channel_map = out_channels_in_stage, repeat = repeat,
                              stage = index + 2)

        # Final 1x1 convolution: wider for the 2x model.
        final_filters = 2048 if scale_factor == 2 else 1024
        x = Conv2D(filters = final_filters, kernel_size = 1, strides = 1, padding = 'same',
                   use_bias = False, activation = 'relu', name = 'conv5')(x)

        if pooling == 'avg':
            x = GlobalAveragePooling2D(name = 'global_average_pool')(x)
        elif pooling == 'max':
            x = GlobalMaxPooling2D(name = 'global_max_pool')(x)

        if include_top:
            x = Dense(classes, name = 'fc')(x)
            x = Activation('softmax')(x)

        # When an input tensor was supplied, trace back to its source inputs
        # so the model is rooted at the original placeholders.
        if input_tensor is not None:
            inputs = get_source_inputs(input_tensor)
        else:
            inputs = img_input

        model = Model(inputs = inputs, outputs = x, name = 'ShuffleNet_V2')

        return model



